import os
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"]="python"
import sentencepiece
from transformers import LlamaTokenizer
from sentencepiece import sentencepiece_model_pb2

# Load the two tokenizers: the HF LLaMA tokenizer, and the raw
# SentencePiece processor whose vocabulary will be merged into it.
llama_tokenizer = LlamaTokenizer.from_pretrained("llama2", legacy=False)
lulu_tokenizer = sentencepiece.SentencePieceProcessor(model_file="lulu/tokenizer.model")


def _parse_model_proto(serialized: bytes):
    """Deserialize a SentencePiece model blob into a ModelProto message."""
    proto = sentencepiece_model_pb2.ModelProto()
    proto.ParseFromString(serialized)
    return proto


# Parse both models into protobuf form so their piece lists can be edited.
llama_mp = _parse_model_proto(llama_tokenizer.sp_model.serialized_model_proto())
lulu_mp = _parse_model_proto(lulu_tokenizer.serialized_model_proto())


# Report vocabulary sizes and the base tokenizer's special-token setup
# before merging, for comparison with the merged result printed later.
print(len(llama_tokenizer), len(lulu_tokenizer))
for detail in (
    llama_tokenizer.all_special_tokens,
    llama_tokenizer.all_special_ids,
    llama_tokenizer.special_tokens_map,
):
    print(detail)


# Add the Chinese tokenizer's pieces to the llama tokenizer:
# append every piece the llama vocabulary does not already contain.
llama_mp_token_set = set(p.piece for p in llama_mp.pieces)
print("Original llama:", len(llama_mp_token_set))
for p in lulu_mp.pieces:
    piece = p.piece
    if piece not in llama_mp_token_set:
        # pieces.add() is the protobuf idiom for appending to a repeated
        # message field — no throwaway ModelProto instance needed.
        new_p = llama_mp.pieces.add()
        new_p.piece = piece
        # NOTE(review): score 0 ranks above the (negative) scores of the
        # original pieces, so merged pieces are preferred during encoding —
        # common practice for vocab merging, but confirm it is intended.
        new_p.score = 0
        # Record the addition so a duplicate piece in the source vocab
        # cannot be appended twice.
        llama_mp_token_set.add(piece)
print("Merged llama:", len(llama_mp.pieces))


# Save the merged SentencePiece model, then re-wrap it as a HF tokenizer
# and export it in Hugging Face format.
os.makedirs("merged", exist_ok=True)  # open() would fail if the dir is missing
with open("merged/tokenizer.model", "wb") as f:
    f.write(llama_mp.SerializeToString())
tokenizer = LlamaTokenizer(vocab_file="merged/tokenizer.model", legacy=False)
print(len(tokenizer))  # merged vocabulary size as seen by the HF tokenizer
tokenizer.save_pretrained("merged_hf")  # save_pretrained creates the dir itself

